import java.util.List;

public class ListNode {
      int val;
      ListNode next;
      ListNode(int x) { val = x; }
  }
class Solution {
    public ListNode removeElements(ListNode head, int val) {
        while(head.val==val){
            ListNode delNode=head;
            head=delNode.next;
            delNode.next=null;
        }
        if(head==null){
            return null;
        }
        ListNode prev=head;
        while (prev.next!=null) {
            if (prev.next.val==val) {
                ListNode delNode=prev.next;
                prev.next=delNode.next;
                delNode.next=null;
            }
            else{
            removeElements(head, val);
            }
    
            
        }
        return prev;
    }
 

}